{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting C:\\Users\\zdwxx\\Downloads\\Compressed\\MNIST_data\\train-images-idx3-ubyte.gz\n",
      "Extracting C:\\Users\\zdwxx\\Downloads\\Compressed\\MNIST_data\\train-labels-idx1-ubyte.gz\n",
      "Extracting C:\\Users\\zdwxx\\Downloads\\Compressed\\MNIST_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting C:\\Users\\zdwxx\\Downloads\\Compressed\\MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From <ipython-input-2-814daf426aa9>:20: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See tf.nn.softmax_cross_entropy_with_logits_v2.\n",
      "\n",
      "WARNING:tensorflow:From <ipython-input-2-814daf426aa9>:30: arg_max (from tensorflow.python.ops.gen_math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `argmax` instead\n",
      "第 0 个周期 准确率是 0.825\n",
      "第 1 个周期 准确率是 0.8924\n",
      "第 2 个周期 准确率是 0.9006\n",
      "第 3 个周期 准确率是 0.9049\n",
      "第 4 个周期 准确率是 0.909\n",
      "第 5 个周期 准确率是 0.91\n",
      "第 6 个周期 准确率是 0.9119\n",
      "第 7 个周期 准确率是 0.9131\n",
      "第 8 个周期 准确率是 0.9151\n",
      "第 9 个周期 准确率是 0.9156\n",
      "第 10 个周期 准确率是 0.9168\n",
      "第 11 个周期 准确率是 0.9182\n",
      "第 12 个周期 准确率是 0.9191\n",
      "第 13 个周期 准确率是 0.9194\n",
      "第 14 个周期 准确率是 0.9205\n",
      "第 15 个周期 准确率是 0.9206\n",
      "第 16 个周期 准确率是 0.9216\n",
      "第 17 个周期 准确率是 0.9212\n",
      "第 18 个周期 准确率是 0.9211\n",
      "第 19 个周期 准确率是 0.9218\n",
      "第 20 个周期 准确率是 0.9218\n"
     ]
    }
   ],
   "source": [
    "# 载入数据集\n",
    "mnist = input_data.read_data_sets(r\"C:\\Users\\zdwxx\\Downloads\\Compressed\\MNIST_data\", one_hot=True)\n",
    "\n",
    "# 定义每个批次的大小\n",
    "batch_size = 100\n",
    "# 计算一共有多少个批次\n",
    "n_batch = mnist.train.num_examples // batch_size\n",
    "\n",
    "# 定义两个placeholder\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "y = tf.placeholder(tf.float32, [None, 10])\n",
    "\n",
    "#创建一个简单的神经网络\n",
    "W = tf.Variable(tf.zeros([784, 10]))\n",
    "b = tf.Variable(tf.zeros([10]))\n",
    "prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 预测值\n",
    "\n",
    "# 二次代价函数\n",
    "\n",
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))\n",
    "\n",
    "# loss = tf.reduce_mean(tf.square(y - prediction))\n",
    "# 梯度下降法\n",
    "train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)\n",
    "\n",
    "# 初始化变量\n",
    "init = tf.global_variables_initializer()\n",
    "\n",
    "# 结果存放在一个布尔型列表中\n",
    "correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(prediction, 1)) #arg_max返回1维张量中最大的值所在的位置\n",
    "# 求准确率\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))#cast转换类型，True->1.0, False->0.0\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    \n",
    "    sess.run(init)\n",
    "    for epoch in range(21):\n",
    "        for batch in range(n_batch):\n",
    "            batch_xs, batch_ys = mnist.train.next_batch(batch_size) #类似于read，一次读取100张图片\n",
    "            sess.run(train_step, feed_dict={x : batch_xs, y : batch_ys})\n",
    "            \n",
    "        acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})\n",
    "        print(\"第\", epoch, \"个周期\", \"准确率是\", acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
